#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

# from beartype import BeartypeConf
# from beartype.claw import beartype_all  # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning))    # <-- emit warnings from all code


import copy
import faulthandler
import logging
import os
import signal
import sys
import threading
import time
import traceback
from datetime import datetime, timezone
from typing import Any

import trio

from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common import settings
from common.config_utils import show_configs
from common.constants import FileSource, TaskStatus
from common.data_source import (
    BlobStorageConnector,
    DiscordConnector,
    GoogleDriveConnector,
    JiraConnector,
    NotionConnector,
)
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.interfaces import CheckpointOutputWrapper
from common.data_source.utils import load_all_docs_from_checkpoint_connector
from common.log_utils import init_root_logger
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common.versions import get_ragflow_version

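# Cap on concurrently running sync tasks (overridable via the MAX_CONCURRENT_TASKS env var).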
MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)


class SyncBase:
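    """Base class for source-specific sync runners.

    Subclasses set SOURCE_NAME and implement _generate(), which returns an iterable
    of document batches that __call__ feeds into the shared sync pipeline.
    """
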
    SOURCE_NAME: str | None = None

    def __init__(self, conf: dict) -> None:
        self.conf = conf

    async def __call__(self, task: dict):
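        """Run one sync task: pull document batches from the source, register them, and schedule the next poll."""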
        SyncLogsService.start(task["id"], task["connector_id"])
        try:
            async with task_limiter:
                with trio.fail_after(task["timeout_secs"]):
                    document_batch_generator = await self._generate(task)
                    doc_num = 0
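                    # Track the newest document timestamp seen; it becomes the start of the next poll window.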
                    next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
                    if task["poll_range_start"]:
                        next_update = task["poll_range_start"]
                    for document_batch in document_batch_generator:
                        if not document_batch:
                            continue
                        min_update = min(doc.doc_updated_at for doc in document_batch)
                        max_update = max(doc.doc_updated_at for doc in document_batch)
                        next_update = max(next_update, max_update)
                        docs = [
                            {
                                "id": doc.id,
                                "connector_id": task["connector_id"],
                                "source": self.SOURCE_NAME,
                                "semantic_identifier": doc.semantic_identifier,
                                "extension": doc.extension,
                                "size_bytes": doc.size_bytes,
                                "doc_updated_at": doc.doc_updated_at,
                                "blob": doc.blob,
                            }
                            for doc in document_batch
                        ]

                        exists, kb = KnowledgebaseService.get_by_id(task["kb_id"])
                        if not exists:
                            raise RuntimeError(f"Knowledgebase {task['kb_id']} was not found.")
                        err, _ = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}", task["auto_parse"])
                        SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
                        doc_num += len(docs)

                    prefix = "[Jira] " if self.SOURCE_NAME == FileSource.JIRA else ""
                    logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
                    SyncLogsService.done(task["id"], task["connector_id"])
                    task["poll_range_start"] = next_update

        except Exception as ex:
            msg = "\n".join(["".join(traceback.format_exception_only(None, ex)).strip(), "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip()])
            SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)})

        SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])

    async def _generate(self, task: dict):
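        """Return an iterable of document batches for this task; implemented by subclasses."""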
        raise NotImplementedError


class S3(SyncBase):
    SOURCE_NAME: str = FileSource.S3

    async def _generate(self, task: dict):
        self.connector = BlobStorageConnector(bucket_type=self.conf.get("bucket_type", "s3"), bucket_name=self.conf["bucket_name"], prefix=self.conf.get("prefix", ""))
        self.connector.load_credentials(self.conf["credentials"])
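        # Full load on reindex or first sync; otherwise poll only the changes since the last sync.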
        document_batch_generator = (
            self.connector.load_from_state()
            if task["reindex"] == "1" or not task["poll_range_start"]
            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
        )

        begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
        logging.info("Connect to {}: {}(prefix/{}) {}".format(self.conf.get("bucket_type", "s3"), self.conf["bucket_name"], self.conf.get("prefix", ""), begin_info))
        return document_batch_generator


class Confluence(SyncBase):
    SOURCE_NAME: str = FileSource.CONFLUENCE

    async def _generate(self, task: dict):
        from common.data_source.config import DocumentSource
        from common.data_source.interfaces import StaticCredentialsProvider

        self.connector = ConfluenceConnector(
            wiki_base=self.conf["wiki_base"],
            space=self.conf.get("space", ""),
            is_cloud=self.conf.get("is_cloud", True),
            # page_id=self.conf.get("page_id", ""),
        )

        credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"])
        self.connector.set_credentials_provider(credentials_provider)

        # Determine the time range for synchronization based on reindex or poll_range_start
        if task["reindex"] == "1" or not task["poll_range_start"]:
            start_time = 0.0
            begin_info = "totally"
        else:
            start_time = task["poll_range_start"].timestamp()
            begin_info = f"from {task['poll_range_start']}"

        end_time = datetime.now(timezone.utc).timestamp()

        document_generator = load_all_docs_from_checkpoint_connector(
            connector=self.connector,
            start=start_time,
            end=end_time,
        )

        logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
        return [document_generator]


class Notion(SyncBase):
    SOURCE_NAME: str = FileSource.NOTION

    async def _generate(self, task: dict):
        self.connector = NotionConnector(root_page_id=self.conf["root_page_id"])
        self.connector.load_credentials(self.conf["credentials"])
        document_generator = (
            self.connector.load_from_state()
            if task["reindex"] == "1" or not task["poll_range_start"]
            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
        )

        begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
        logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info))
        return document_generator


class Discord(SyncBase):
    SOURCE_NAME: str = FileSource.DISCORD

    async def _generate(self, task: dict):
        # Comma-separated lists, e.g. "server1,server2" and "channel1,channel2"
        server_ids: str | None = self.conf.get("server_ids", None)
        channel_names: str | None = self.conf.get("channel_names", None)

        self.connector = DiscordConnector(
            server_ids=server_ids.split(",") if server_ids else [],
            channel_names=channel_names.split(",") if channel_names else [],
            start_date=datetime(1970, 1, 1, tzinfo=timezone.utc).strftime("%Y-%m-%d"),
            batch_size=self.conf.get("batch_size", 1024),
        )
        self.connector.load_credentials(self.conf["credentials"])
        document_generator = (
            self.connector.load_from_state()
            if task["reindex"] == "1" or not task["poll_range_start"]
            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
        )

        begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
        logging.info("Connect to Discord: servers({}),  channel({}) {}".format(server_ids, channel_names, begin_info))
        return document_generator


class Gmail(SyncBase):
    SOURCE_NAME: str = FileSource.GMAIL

    async def _generate(self, task: dict):
        # Not implemented yet; fail explicitly rather than returning None, which the
        # sync loop would try to iterate and crash on with an opaque TypeError.
        raise NotImplementedError("Gmail synchronization is not implemented yet.")


class GoogleDrive(SyncBase):
    SOURCE_NAME: str = FileSource.GOOGLE_DRIVE

    async def _generate(self, task: dict):
        connector_kwargs = {
            "include_shared_drives": self.conf.get("include_shared_drives", False),
            "include_my_drives": self.conf.get("include_my_drives", False),
            "include_files_shared_with_me": self.conf.get("include_files_shared_with_me", False),
            "shared_drive_urls": self.conf.get("shared_drive_urls"),
            "my_drive_emails": self.conf.get("my_drive_emails"),
            "shared_folder_urls": self.conf.get("shared_folder_urls"),
            "specific_user_emails": self.conf.get("specific_user_emails"),
            "batch_size": self.conf.get("batch_size", INDEX_BATCH_SIZE),
        }
        self.connector = GoogleDriveConnector(**connector_kwargs)
        self.connector.set_allow_images(self.conf.get("allow_images", False))

        credentials = self.conf.get("credentials")
        if not credentials:
            raise ValueError("Google Drive connector is missing credentials.")

        new_credentials = self.connector.load_credentials(credentials)
        if new_credentials:
            self._persist_rotated_credentials(task["connector_id"], new_credentials)

        if task["reindex"] == "1" or not task["poll_range_start"]:
            start_time = 0.0
            begin_info = "totally"
        else:
            start_time = task["poll_range_start"].timestamp()
            begin_info = f"from {task['poll_range_start']}"

        end_time = datetime.now(timezone.utc).timestamp()
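        # Normalize the configured batch size; fall back to INDEX_BATCH_SIZE when it is missing or invalid.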
        raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE
        try:
            batch_size = int(raw_batch_size)
        except (TypeError, ValueError):
            batch_size = INDEX_BATCH_SIZE
        if batch_size <= 0:
            batch_size = INDEX_BATCH_SIZE

        def document_batches():
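            # Drain the checkpoint-based connector, regrouping the documents it yields into fixed-size batches.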
            checkpoint = self.connector.build_dummy_checkpoint()
            pending_docs = []
            iterations = 0
            iteration_limit = 100_000

            while checkpoint.has_more:
                wrapper = CheckpointOutputWrapper()
                doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
                for document, failure, next_checkpoint in doc_generator:
                    if failure is not None:
                        logging.warning("Google Drive connector failure: %s", getattr(failure, "failure_message", failure))
                        continue
                    if document is not None:
                        pending_docs.append(document)
                        if len(pending_docs) >= batch_size:
                            yield pending_docs
                            pending_docs = []
                    if next_checkpoint is not None:
                        checkpoint = next_checkpoint

                iterations += 1
                if iterations > iteration_limit:
                    raise RuntimeError("Too many iterations while loading Google Drive documents.")

            if pending_docs:
                yield pending_docs

        try:
            admin_email = self.connector.primary_admin_email
        except RuntimeError:
            admin_email = "unknown"
        logging.info(f"Connect to Google Drive as {admin_email} {begin_info}")
        return document_batches()

    def _persist_rotated_credentials(self, connector_id: str, credentials: dict[str, Any]) -> None:
        try:
            updated_conf = copy.deepcopy(self.conf)
            updated_conf["credentials"] = credentials
            ConnectorService.update_by_id(connector_id, {"config": updated_conf})
            self.conf = updated_conf
            logging.info("Persisted refreshed Google Drive credentials for connector %s", connector_id)
        except Exception:
            logging.exception("Failed to persist refreshed Google Drive credentials for connector %s", connector_id)


class Jira(SyncBase):
    SOURCE_NAME: str = FileSource.JIRA

    async def _generate(self, task: dict):
        connector_kwargs = {
            "jira_base_url": self.conf["base_url"],
            "project_key": self.conf.get("project_key"),
            "jql_query": self.conf.get("jql_query"),
            "batch_size": self.conf.get("batch_size", INDEX_BATCH_SIZE),
            "include_comments": self.conf.get("include_comments", True),
            "include_attachments": self.conf.get("include_attachments", False),
            "labels_to_skip": self._normalize_list(self.conf.get("labels_to_skip")),
            "comment_email_blacklist": self._normalize_list(self.conf.get("comment_email_blacklist")),
            "scoped_token": self.conf.get("scoped_token", False),
            "attachment_size_limit": self.conf.get("attachment_size_limit"),
            "timezone_offset": self.conf.get("timezone_offset"),
        }

        self.connector = JiraConnector(**connector_kwargs)

        credentials = self.conf.get("credentials")
        if not credentials:
            raise ValueError("Jira connector is missing credentials.")

        self.connector.load_credentials(credentials)
        self.connector.validate_connector_settings()

        if task["reindex"] == "1" or not task["poll_range_start"]:
            start_time = 0.0
            begin_info = "totally"
        else:
            start_time = task["poll_range_start"].timestamp()
            begin_info = f"from {task['poll_range_start']}"

        end_time = datetime.now(timezone.utc).timestamp()

        raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE
        try:
            batch_size = int(raw_batch_size)
        except (TypeError, ValueError):
            batch_size = INDEX_BATCH_SIZE
        if batch_size <= 0:
            batch_size = INDEX_BATCH_SIZE

        def document_batches():
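            # Same batching strategy as the Google Drive connector: drain the checkpoint
            # connector and regroup its documents into fixed-size batches.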
            checkpoint = self.connector.build_dummy_checkpoint()
            pending_docs = []
            iterations = 0
            iteration_limit = 100_000

            while checkpoint.has_more:
                wrapper = CheckpointOutputWrapper()
                generator = wrapper(
                    self.connector.load_from_checkpoint(
                        start_time,
                        end_time,
                        checkpoint,
                    )
                )
                for document, failure, next_checkpoint in generator:
                    if failure is not None:
                        logging.warning(
                            f"[Jira] Jira connector failure: {getattr(failure, 'failure_message', failure)}"
                        )
                        continue
                    if document is not None:
                        pending_docs.append(document)
                        if len(pending_docs) >= batch_size:
                            yield pending_docs
                            pending_docs = []
                    if next_checkpoint is not None:
                        checkpoint = next_checkpoint

                iterations += 1
                if iterations > iteration_limit:
                    logging.error(f"[Jira] Task {task.get('id')} exceeded iteration limit ({iteration_limit}).")
                    raise RuntimeError("Too many iterations while loading Jira documents.")

            if pending_docs:
                yield pending_docs

        logging.info(f"[Jira] Connect to Jira {connector_kwargs['jira_base_url']} {begin_info}")
        return document_batches()

    @staticmethod
    def _normalize_list(values: Any) -> list[str] | None:
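        """Accept a comma-separated string or an iterable; return trimmed, non-empty strings."""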
        if values is None:
            return None
        if isinstance(values, str):
            values = [item.strip() for item in values.split(",")]
        return [str(value).strip() for value in values if value is not None and str(value).strip()]


class SharePoint(SyncBase):
    SOURCE_NAME: str = FileSource.SHAREPOINT

    async def _generate(self, task: dict):
        raise NotImplementedError("SharePoint synchronization is not implemented yet.")


class Slack(SyncBase):
    SOURCE_NAME: str = FileSource.SLACK

    async def _generate(self, task: dict):
        raise NotImplementedError("Slack synchronization is not implemented yet.")


class Teams(SyncBase):
    SOURCE_NAME: str = FileSource.TEAMS

    async def _generate(self, task: dict):
        raise NotImplementedError("Teams synchronization is not implemented yet.")


func_factory = {
    FileSource.S3: S3,
    FileSource.NOTION: Notion,
    FileSource.DISCORD: Discord,
    FileSource.CONFLUENCE: Confluence,
    FileSource.GMAIL: Gmail,
    FileSource.GOOGLE_DRIVE: GoogleDrive,
    FileSource.JIRA: Jira,
    FileSource.SHAREPOINT: SharePoint,
    FileSource.SLACK: Slack,
    FileSource.TEAMS: Teams,
}


async def dispatch_tasks():
    async with trio.open_nursery() as nursery:
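        # Block until the DB answers a trivial query; the service may come up before the DB does.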
        while True:
            try:
                list(SyncLogsService.list_sync_tasks()[0])
                break
            except Exception as e:
                logging.warning(f"DB is not ready yet: {e}")
                await trio.sleep(3)

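        # Normalize poll ranges to UTC and fan each task out to its source-specific runner.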
        for task in SyncLogsService.list_sync_tasks()[0]:
            if task["poll_range_start"]:
                task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
            if task["poll_range_end"]:
                task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc)
            func = func_factory[task["source"]](task["config"])
            nursery.start_soon(func, task)
    await trio.sleep(1)  # Throttle the dispatch loop between scheduling passes.


stop_event = threading.Event()


def signal_handler(sig, frame):
    logging.info("Received interrupt signal, shutting down...")
    stop_event.set()
    time.sleep(1)  # Give in-flight tasks a moment to observe the stop event before exiting.
    sys.exit(0)


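# An optional CLI argument selects the consumer id so multiple sync workers can run side by side.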
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "data_sync_" + CONSUMER_NO


async def main():
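    """Initialize settings and signal handlers, then dispatch sync tasks until stopped."""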
    logging.info(r"""
  _____        _           _____
 |  __ \      | |         / ____|
 | |  | | __ _| |_ __ _  | (___  _   _ _ __   ___
 | |  | |/ _` | __/ _` |  \___ \| | | | '_ \ / __|
 | |__| | (_| | || (_| |  ____) | |_| | | | | (__
 |_____/ \__,_|\__\__,_| |_____/ \__, |_| |_|\___|
                                  __/ |
                                 |___/
    """)
    logging.info(f"RAGFlow version: {get_ragflow_version()}")
    show_configs()
    settings.init_settings()
    if sys.platform != "win32":
        signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
        signal.signal(signal.SIGUSR2, stop_tracemalloc)
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    while not stop_event.is_set():
        await dispatch_tasks()
    logging.error("BUG!!! You should not reach here!!!")


if __name__ == "__main__":
    faulthandler.enable()
    init_root_logger(CONSUMER_NAME)
    trio.run(main)
